In this work I explain predictions of a Random Forest model and a Logistic Regression model using
Local Interpretable Model-agnostic Explanations (LIME) as implemented in the Python package lime. The dataset
is the Heart Attack Analysis dataset
(source).
Dataset attributes:
I preprocessed the dataset by one-hot encoding the categorical features.
The correlation matrix shows that thall_2, thalachh and slp_2 have the strongest correlation with the output.
I ran the explainer on the same samples with 3 different seeds. The seed does not have a large impact
on the results for either sample 203 or sample 246: the importance values differ slightly, but not significantly.
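One way to quantify this stability (not part of the original notebook) is the rank correlation between the importance vectors produced under two seeds. A minimal sketch with made-up weights, purely for illustration:

```python
import numpy as np
from scipy.stats import spearmanr

# Hypothetical LIME weights for one sample under two seeds
# (illustrative values, not taken from the actual runs).
weights_seed0 = {"thall_2": 0.31, "thalachh": 0.18, "slp_2": 0.12, "caa_1": -0.09}
weights_seed1 = {"thall_2": 0.29, "thalachh": 0.20, "slp_2": 0.10, "caa_1": -0.07}

features = sorted(weights_seed0)
rho, _ = spearmanr([weights_seed0[f] for f in features],
                   [weights_seed1[f] for f in features])
print(rho)  # rho close to 1 means both seeds rank the features almost identically
```

A Spearman rho near 1 supports the qualitative observation above that the seed changes magnitudes slightly but barely changes the ranking.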
The feature thall_2 (Thallium stress test result value 2, out of [0, 3]) is the most important feature for both
samples, which might indicate a high diagnostic value of this particular medical test. Another important feature is thalachh
(maximum heart rate), which for a non-expert is the easiest to relate to the chance of a heart attack.
I compare LIME from the lime package to SHAP from the dalex package. thall_2 is the most important feature for both
samples and both explainers, which is not surprising, as thall_2 has the highest correlation with the output
of all features. Both explainers give similar results, but it is worth
noting that for sample 203 SHAP marked oldpeak as one of the most important features, while it does not appear among the top
10 features from LIME. Conversely, LIME ranked the caa_x features high, while they matter less in the
SHAP explanations.
LIME assigns very different importances to the features for the Random Forest model and the Logistic Regression model. Nonetheless, all top features shared by both models have the same attribution sign (positive/negative). For the Logistic Regression model, unlike the Random Forest model, LIME does not find the thalachh feature important; this might be because the data is not normalized and the linear model does not adapt well to a variable with values in the range [71, 202].
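LIME's core mechanism, sampling perturbations around the instance and fitting a distance-weighted linear surrogate to the black-box predictions, can be sketched as follows. This is a simplification of what the lime package actually does (it also discretizes features and selects a feature subset); the toy model and kernel width are my own choices:

```python
import numpy as np
from sklearn.linear_model import Ridge

rng = np.random.default_rng(0)

# Toy black-box classifier: class-1 probability depends mostly on feature 0.
def predict_proba(X):
    z = 3 * X[:, 0] + 0.5 * X[:, 1]
    p = 1 / (1 + np.exp(-z))
    return np.column_stack([1 - p, p])

x0 = np.array([0.2, -0.1, 0.4])                 # instance to explain
X = x0 + rng.normal(scale=0.5, size=(500, 3))   # perturbations around x0
dist = np.linalg.norm(X - x0, axis=1)
weights = np.exp(-(dist ** 2) / 0.5)            # exponential proximity kernel

# Weighted linear surrogate of the black-box, valid only near x0.
surrogate = Ridge(alpha=1.0).fit(X, predict_proba(X)[:, 1], sample_weight=weights)
print(surrogate.coef_)  # feature 0 should dominate, mirroring the black-box
```

The surrogate's coefficients play the role of the per-feature importances that the lime plots below visualize.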
%%capture
%pip install dalex jinja2 kaleido lime numpy nbformat pandas plotly torch scikit-learn shap
import dalex as dx
import lime
import lime.lime_tabular  # the submodule is not exposed by `import lime` alone
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import plotly.express as px
import shap
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import RidgeClassifier
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler
rng = np.random.default_rng(0)
TARGET_COLUMN = "output"
df = pd.read_csv("heart.csv")
df.describe()
Shuffling the data, extracting the target column and one-hot encoding the categorical columns.
df = df.sample(frac=1, random_state=0).reset_index(drop=True)
y = df[[TARGET_COLUMN]]
x = df.drop(TARGET_COLUMN, axis=1)
categorical_cols = ["sex", "cp", "fbs", "restecg", "exng", "slp", "caa", "thall"]
numerical_cols = list(set(x.columns) - set(categorical_cols))
x = pd.get_dummies(x, columns=categorical_cols, drop_first=True)
n_columns = len(x.columns)
categorical_cols, numerical_cols
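As a small illustration of what drop_first=True does, here is a toy column (values chosen for the example, not taken from the dataset): one level becomes the implicit baseline and gets no indicator column.

```python
import pandas as pd

toy = pd.DataFrame({"thall": [0, 1, 2, 3]})
dummies = pd.get_dummies(toy, columns=["thall"], drop_first=True)
print(list(dummies.columns))  # ['thall_1', 'thall_2', 'thall_3']; thall_0 is the dropped baseline
```

This is why a feature such as thall_2 appears in the explanations while thall_0 does not.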
Not being an expert in cardiology, I plot the correlation matrix to see how each column relates to the target.
corr_df = x.copy()
corr_df[TARGET_COLUMN] = y
corr = corr_df.corr("pearson")
corr.style.background_gradient(cmap='coolwarm')
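The earlier claim that thall_2, thalachh and slp_2 lead comes from reading this matrix; the target column's correlations can also be ranked programmatically. A sketch on synthetic data (heart.csv is not bundled here, so the column names are placeholders):

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(0)
n = 200
demo = pd.DataFrame({
    "a": rng.normal(size=n),
    "b": rng.normal(size=n),
})
# Binary target driven almost entirely by column "a".
demo["output"] = (2 * demo["a"] + 0.1 * rng.normal(size=n) > 0).astype(int)

ranked = demo.corr("pearson")["output"].drop("output").abs().sort_values(ascending=False)
print(ranked.index[0])  # "a" correlates most strongly with the target
```

Applied to corr_df, the same one-liner would surface thall_2, thalachh and slp_2 at the top.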
model = RandomForestClassifier(random_state=0).fit(x, y.squeeze())
accuracy_score(y, model.predict(x))
sample_ids = [42, 81, 203, 246]
df.iloc[sample_ids]
lime_explainer = lime.lime_tabular.LimeTabularExplainer(
training_data=x.values,
feature_names=x.columns,
mode="classification",
)
lime_explanations = [lime_explainer.explain_instance(
data_row=x.iloc[i],
predict_fn=lambda d: model.predict_proba(d)
) for i in sample_ids]
lime_explanations[0].as_list()
_ = lime_explanations[0].as_pyplot_figure()
_ = lime_explanations[0].show_in_notebook()
def plot_explanation(lime_explanation, filename, title):
_ = lime_explanation.show_in_notebook()
_ = lime_explanation.as_pyplot_figure()
plt.title(title)
plt.savefig(filename, bbox_inches='tight')
def explain(seed):
lime_explainer = lime.lime_tabular.LimeTabularExplainer(
training_data=x.values,
feature_names=x.columns,
mode="classification",
random_state=seed
)
lime_explanations = [lime_explainer.explain_instance(
data_row=x.iloc[i],
predict_fn=lambda d: model.predict_proba(d)
) for i in sample_ids]
for idx, lime_explanation in enumerate(lime_explanations):
plot_explanation(lime_explanation, f"imgs/lime_{idx}_seed_{seed}.png", f"Random Forest, sample={sample_ids[idx]}, seed={seed}")
for seed in range(3):
explain(seed)
# RidgeClassifier stands in for the logistic-regression-style linear model;
# a sigmoid over its decision scores (below) provides the probabilities LIME needs.
lr_clf = RidgeClassifier(random_state=0).fit(x, y.squeeze())
accuracy_score(y, lr_clf.predict(x))
def lr_predict_func(d):
# Map Ridge decision scores to pseudo-probabilities with a sigmoid,
# so LIME can treat this model as if it exposed predict_proba.
pred = lr_clf.decision_function(d)
res = 1 / (1 + np.exp(-pred))
return np.array([1 - res, res]).T
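A quick standalone sanity check of this score-to-probability conversion (the function name and input scores are illustrative, mirroring lr_predict_func above):

```python
import numpy as np

def scores_to_proba(scores):
    # Sigmoid maps each real-valued decision score to a class-1 probability;
    # stacking with its complement yields a (n, 2) predict_proba-shaped array.
    p = 1 / (1 + np.exp(-np.asarray(scores)))
    return np.column_stack([1 - p, p])

proba = scores_to_proba([-2.0, 0.0, 3.5])
print(proba.sum(axis=1))  # each row sums to 1, as LIME expects
```

Without this wrapper, LIME's classification mode cannot use RidgeClassifier, which only exposes decision_function.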
lr_lime_explanations = [lime_explainer.explain_instance(
data_row=x.iloc[i],
predict_fn=lr_predict_func
) for i in sample_ids]
for idx, lime_explanation in enumerate(lr_lime_explanations):
plot_explanation(lime_explanation, f"imgs/lr_lime_{idx}.png", f"Logistic Regression, sample={sample_ids[idx]}")